/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "c/ddk/graph/graph.h"
#include "c/ddk/graph/context.h"

#include <memory>
#include <string>

#include "resource_manager.h"
#include "framework/infra/log/log.h"
#include "graph/graph.h"
#include "graph/operator.h"
#include "framework/graph/debug/ge_log.h"

/**
 * @brief Creates an IR graph and registers it with the resource manager.
 *
 * The graph is created as a shared object and handed to resMgr via StoreSrcPtr. The local
 * shared pointer is released when this function returns, so the returned raw handle stays
 * valid only as long as resMgr keeps that stored reference alive.
 * NOTE(review): StoreSrcPtr's result (if any) is not checked here — confirm it cannot fail.
 *
 * @param resMgr Resource manager that will hold the graph; must not be nullptr.
 * @param name   Graph name; must not be nullptr (constructing std::string from nullptr is UB).
 * @return Graph handle on success, nullptr if resMgr or name is nullptr.
 */
GraphHandle HIAI_IR_GraphCreate(ResMgrHandle resMgr, const char* name)
{
    if (resMgr == nullptr) {
        FMK_LOGE("resMgr is nullptr");
        return nullptr;
    }
    if (name == nullptr) {
        FMK_LOGE("name is nullptr");
        return nullptr;
    }
    BasePtr graph = std::make_shared<ge::Graph>(std::string(name));
    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    // Hand shared ownership to the resource manager so the raw handle below does not dangle.
    resMgrPtr->StoreSrcPtr(graph);
    return graph.get();
}


/**
 * @brief Sets the input operators of a graph.
 *
 * Each element of inputs is a raw operator handle registered with resMgr. Each handle is
 * resolved back to its shared ge::Operator before the graph is touched. The graph is only
 * modified after every handle has been resolved and type-checked successfully.
 *
 * @param resMgr   Resource manager that holds the graph and operators; must not be nullptr.
 * @param graph    Graph handle registered with resMgr; must not be nullptr.
 * @param inputs   Caller-allocated array of inputNum operator handles; must not be nullptr.
 * @param inputNum Number of elements in inputs.
 * @return HIAI_SUCCESS on success. HIAI_FAILURE if an argument is null, a handle is unknown
 *         to resMgr, or a handle does not refer to an object of the expected type.
 */
HIAI_Status HIAI_IR_SetInputs(ResMgrHandle resMgr, ConstGraphHandle graph, ConstOpHandle inputs[], uint32_t inputNum)
{
    if (resMgr == nullptr || graph == nullptr || inputs == nullptr) {
        FMK_LOGE("resMgr or graph or inputs is nullptr");
        return HIAI_FAILURE;
    }

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    std::vector<ge::Operator> inNodes;
    inNodes.reserve(inputNum);
    for (uint32_t i = 0; i < inputNum; ++i) {
        // The inputs array is created by the caller; the operators it refers to are managed by resMgr.
        BasePtr inputBasePtr = resMgrPtr->GetSrcPtr(inputs[i]);
        if (inputBasePtr == nullptr) {
            FMK_LOGE("get operator ptr failed");
            return HIAI_FAILURE;
        }
        // dynamic_pointer_cast yields nullptr when the handle is not an Operator;
        // dereferencing it unchecked would be undefined behaviour.
        std::shared_ptr<ge::Operator> inputPtr = std::dynamic_pointer_cast<ge::Operator>(inputBasePtr);
        if (inputPtr == nullptr) {
            FMK_LOGE("input handle is not an operator");
            return HIAI_FAILURE;
        }
        inNodes.push_back(*inputPtr);
    }

    // Resolve the raw graph handle back to its C++ base pointer.
    BasePtr graphBasePtr = resMgrPtr->GetSrcPtr(graph);
    if (graphBasePtr == nullptr) {
        FMK_LOGE("get graph base ptr failed");
        return HIAI_FAILURE;
    }
    // Downcast the base pointer to a Graph; reject handles of any other type.
    std::shared_ptr<ge::Graph> graphPtr = std::dynamic_pointer_cast<ge::Graph>(graphBasePtr);
    if (graphPtr == nullptr) {
        FMK_LOGE("graph handle is not a graph");
        return HIAI_FAILURE;
    }
    // Delegate to the C++ graph API.
    graphPtr->SetInputs(inNodes);

    return HIAI_SUCCESS;
}

/**
 * @brief Sets the output operators of a graph.
 *
 * Each element of outputs is a raw operator handle registered with resMgr. Each handle is
 * resolved back to its shared ge::Operator before the graph is touched. The graph is only
 * modified after every handle has been resolved and type-checked successfully.
 *
 * @param resMgr    Resource manager that holds the graph and operators; must not be nullptr.
 * @param graph     Graph handle registered with resMgr; must not be nullptr.
 * @param outputs   Caller-allocated array of outputNum operator handles; must not be nullptr.
 * @param outputNum Number of elements in outputs.
 * @return HIAI_SUCCESS on success. HIAI_FAILURE if an argument is null, a handle is unknown
 *         to resMgr, or a handle does not refer to an object of the expected type.
 */
HIAI_Status HIAI_IR_SetOutputs(ResMgrHandle resMgr, ConstGraphHandle graph, ConstOpHandle outputs[], uint32_t outputNum)
{
    if (resMgr == nullptr || graph == nullptr || outputs == nullptr) {
        FMK_LOGE("resMgr or graph or outputs is nullptr");
        return HIAI_FAILURE;
    }

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    std::vector<ge::Operator> outNodes;
    outNodes.reserve(outputNum);
    for (uint32_t i = 0; i < outputNum; ++i) {
        // The outputs array is created by the caller; the operators it refers to are managed by resMgr.
        BasePtr outputBasePtr = resMgrPtr->GetSrcPtr(outputs[i]);
        if (outputBasePtr == nullptr) {
            FMK_LOGE("get operator ptr failed");
            return HIAI_FAILURE;
        }
        // dynamic_pointer_cast yields nullptr when the handle is not an Operator;
        // dereferencing it unchecked would be undefined behaviour.
        std::shared_ptr<ge::Operator> outputPtr = std::dynamic_pointer_cast<ge::Operator>(outputBasePtr);
        if (outputPtr == nullptr) {
            FMK_LOGE("output handle is not an operator");
            return HIAI_FAILURE;
        }
        outNodes.push_back(*outputPtr);
    }

    // Resolve the raw graph handle back to its C++ base pointer.
    BasePtr graphBasePtr = resMgrPtr->GetSrcPtr(graph);
    if (graphBasePtr == nullptr) {
        FMK_LOGE("get graph base ptr failed");
        return HIAI_FAILURE;
    }
    // Downcast the base pointer to a Graph; reject handles of any other type.
    std::shared_ptr<ge::Graph> graphPtr = std::dynamic_pointer_cast<ge::Graph>(graphBasePtr);
    if (graphPtr == nullptr) {
        FMK_LOGE("graph handle is not a graph");
        return HIAI_FAILURE;
    }
    // Delegate to the C++ graph API.
    graphPtr->SetOutputs(outNodes);

    return HIAI_SUCCESS;
}